import sys
import time
import torch
import wandb
from datetime import datetime
from tool.util  import init_wandb
from tool.args import get_general_args
from tool.debug_f import is_m, gpu_show
from tool.logger import Logger
from train.mlbase import MLBase
from train.train import Trainer
from train.lr import adjust_lr, warmup_lr
from evaluate.evaluator import Evaluator
import torch.backends.cudnn as cudnn
# Global cuDNN settings for this process: keep cuDNN on and enable benchmark
# mode (per PyTorch docs, algorithms are chosen by timing, which pays off when
# input shapes stay fixed).
cudnn.enabled = True
cudnn.benchmark = True


class LearningFramework(MLBase):
    """Epoch/iteration driver that trains ``self.model`` and reports progress.

    Per iteration, ``Trainer`` computes the loss and this class runs the
    optimizer step. Per epoch, ``Evaluator`` is invoked (unless ``args.crl``
    is set) and ``save_model`` is called with the epoch's average loss and
    the evaluator's ``best_acc``.

    NOTE(review): ``self.args``, ``self.model``, ``self.optimizer``,
    ``self.tr_dl``, ``self.uc_dl`` and the ``resume``/``save_model``/``reload``
    methods are not defined here. They are assumed to be provided by
    ``MLBase`` (initialised from ``p`` via ``other=p``). Confirm.
    """

    # With ``args.reload``, a batch whose |loss| exceeds this (or is NaN/inf)
    # triggers ``reload`` and is skipped instead of being back-propagated.
    _RELOAD_LOSS_LIMIT = 1e5

    def __init__(self, p):
        """Build the framework from an already-constructed ``MLBase`` ``p``."""
        super().__init__(other=p)
        self.trainer = Trainer(p)
        self.evaluator = Evaluator(p)
        # Framework-level meters; loop_epoch records data/batch timings and loss.
        self.log = Logger()
        self.lr = 0

    def __call__(self):
        """Run the full training schedule (same as ``loop_learn()``)."""
        self.loop_learn()

    def loop_learn(self):
        """Train from the resumed epoch up to ``args.epochs`` and report accuracy.

        Uses ``self.args`` instead of the module-level ``args`` global, so the
        class also works when imported rather than run as ``__main__``.
        """
        args = self.args
        start_epoch, self.evaluator.best_acc = self.resume()
        # self.epoch is the loop target because loop_epoch/print_iter read it.
        for self.epoch in range(start_epoch, args.epochs):
            print(f"== Epoch [{self.epoch}] has been started at [{datetime.now():%Y.%m.%d %H:%M:%S}] ==")
            self.log.reset()
            adjust_lr(args, self.epoch, self.optimizer, self.log)
            loss, epoch_ts = self.loop_epoch()
            if not args.crl:
                self.evaluator(self.epoch)
            self.save_model(loss, self.evaluator.best_acc)
            print('Epoch time {:.2f}'.format(epoch_ts))
        print("Top-1 test accuracy: {acc:.1f}".format(acc=self.evaluator.best_acc))

    def loop_epoch(self):
        """Run one pass over ``self.tr_dl`` with one optimizer step per batch.

        Returns:
            A tuple ``(avg_loss, seconds)``: the average loss recorded in
            ``self.log`` and the wall-clock duration of the epoch. Batches
            skipped via ``reload`` are not included in the average.
        """
        self.model.train()
        self.trainer.log.reset()
        epoch_start = end = time.time()
        for it, (x, y) in enumerate(self.tr_dl):
            self.log.update(1, data_ts=time.time() - end)
            warmup_lr(self.args, self.epoch, self.tr_dl, it, self.optimizer, self.log)
            if self.args.crl:
                # x is a two-element sequence of tensors; join them along dim 0.
                x = torch.cat([x[0], x[1]], dim=0)
            x = x.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            bsz = y.shape[0]
            x_uc = next(self.uc_dl)[0].cuda(non_blocking=True) if self.args.uc_dl else None
            loss = self.trainer(x, y, x_uc)
            loss_val = loss.item()  # single host sync per iteration
            # `not (|loss| <= limit)` is also True for NaN/inf, which a plain
            # `> limit` test would let through into backward().
            if self.args.reload and not (abs(loss_val) <= self._RELOAD_LOSS_LIMIT):
                self.reload(loss)
                # Restart the data timer so the next data_ts excludes reload time.
                end = time.time()
                continue
            self.log.update(bsz, loss=loss_val)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.log.update(1, batch_ts=time.time() - end)
            self.print_iter(it)
            end = time.time()
        return self.log.loss.avg, time.time() - epoch_start

    def print_iter(self, it):
        """Every ``args.print_freq`` iterations, print both loggers and send them to wandb."""
        if (it + 1) % self.args.print_freq != 0:
            return
        d_lf, s_lf = self.log.out()
        d_tr, s_tr = self.trainer.log.out()
        d = {**d_lf, **d_tr}
        s = s_lf + s_tr
        print('Train: [{0}][{1}/{2}]\t'.format(self.epoch, it + 1, len(self.tr_dl)) + s)
        sys.stdout.flush()
        wandb.log(d)


if __name__ == '__main__':
    # Script entry point: parse CLI args, start wandb, then run training.
    print(datetime.now())
    # Keep `args` at module scope: code in this module may read it as a global.
    args = get_general_args()
    init_wandb(args)
    framework = LearningFramework(MLBase(args))
    framework()
